import pyopencl as cl
import numpy as np
import math

# -------------------------------
# Constants
# -------------------------------
# Golden ratio, truncated to ten decimal places.
PHI = 1.6180339887
# Square root of PHI; passed to the kernel as the discretization threshold and
# used as the lower bound of the random workspace values.
SQRT_PHI = math.sqrt(PHI)
# Number of lattice instances laid out back to back in the flat buffers.
INSTANCES = 8
# Number of slots belonging to each instance.
SLOTS_PER_INSTANCE = 32
# Number of kernel launches performed by the evolution loop.
EVOLUTION_TICKS = 100

# -------------------------------
# Platform / Device / Context
# -------------------------------
# Select the first device exposed by any installed platform. Indexing
# platforms[0].get_devices()[0] directly fails with a bare IndexError when the
# first platform has no device, even if a later platform has one.
platforms = cl.get_platforms()
device = next(
    (dev for platform in platforms for dev in platform.get_devices()),
    None,
)
if device is None:
    raise RuntimeError("No OpenCL device found on any installed platform")

# Single-device context with its command queue; `mf` is a shorthand for the
# buffer-creation flags used below.
ctx = cl.Context([device])
queue = cl.CommandQueue(ctx)
mf = cl.mem_flags

# -------------------------------
# OpenCL Kernel
# -------------------------------
# Kernel contract: one work-item per lattice slot. Work-item `gid` reads
# workspace[gid] (and, for every instance but the first, the same slot of the
# previous instance) and writes exactly one element, lattice[gid].
kernel_source = """
__kernel void lattice_evolve(
    __global float *lattice,
    __global const float *workspace,
    const float threshold,
    const int slots_per_instance
) {
    int gid = get_global_id(0);
    int instance = gid / slots_per_instance;

    // Simple binary discretization of this slot
    float state = workspace[gid] >= threshold ? 1.0f : 0.0f;

    // Superposition effect: small cross-instance averaging.
    // The previous instance's state is re-derived from the read-only
    // workspace instead of being read back from lattice[]: that element is
    // written by a different work-item during the same launch, so reading
    // it here would be an unsynchronized data race.
    if (instance > 0) {
        float prev = workspace[gid - slots_per_instance] >= threshold ? 1.0f : 0.0f;
        state = 0.5f * (state + prev);
    }

    // Single write per work-item
    lattice[gid] = state;
}
"""

# Compile for the selected context and fetch the kernel entry point.
program = cl.Program(ctx, kernel_source).build()
kernel = program.lattice_evolve

# -------------------------------
# Allocate buffers
# -------------------------------
# Flat layout: instance i owns the slots
# [i * SLOTS_PER_INSTANCE, (i + 1) * SLOTS_PER_INSTANCE).
total_slots = INSTANCES * SLOTS_PER_INSTANCE

# Output lattice: zero-filled on the host and copied to the device on creation.
lattice_host = np.zeros(total_slots, dtype=np.float32)
lattice_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=lattice_host)

# Input workspace: each instance's slice is filled by its own uniform draw
# from [SQRT_PHI, 2 * SQRT_PHI). The rows below are views into the flat array,
# so assigning to a row writes straight into workspace_host (cast to float32).
workspace_host = np.zeros(total_slots, dtype=np.float32)
for instance_row in workspace_host.reshape(INSTANCES, SLOTS_PER_INSTANCE):
    instance_row[:] = np.random.uniform(
        low=SQRT_PHI, high=2 * SQRT_PHI, size=SLOTS_PER_INSTANCE
    )
workspace_buf = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=workspace_host)

# -------------------------------
# Set kernel args
# -------------------------------
# Scalars are wrapped in fixed-width NumPy types so they match the kernel's
# `float` and `int` parameters.
threshold_arg = np.float32(SQRT_PHI)
slots_arg = np.int32(SLOTS_PER_INSTANCE)

# Argument order follows the kernel signature:
# lattice, workspace, threshold, slots_per_instance.
kernel.set_args(lattice_buf, workspace_buf, threshold_arg, slots_arg)

# -------------------------------
# Evolution loop
# -------------------------------
# Launch the kernel once per tick with one work-item per slot; on every 20th
# tick, copy the lattice back and print one line per instance.
launch_shape = (total_slots,)
for tick in range(EVOLUTION_TICKS):
    cl.enqueue_nd_range_kernel(queue, kernel, launch_shape, None)

    if tick % 20 != 0:
        continue

    # Snapshot the device lattice into the host array.
    cl.enqueue_copy(queue, lattice_host, lattice_buf)

    # View the flat snapshot as one row per instance.
    instance_rows = lattice_host.reshape(INSTANCES, SLOTS_PER_INSTANCE)
    for number, row in enumerate(instance_rows, start=1):
        # '#' marks a slot whose value is at least 1, '.' anything below.
        console_out = ''.join('#' if cell >= 1 else '.' for cell in row)
        print(f"[Tick {tick}] Instance {number}: {console_out}")

# -------------------------------
# Final aggregated lattice
# -------------------------------
# Read the lattice back as it stands after the last tick. The evolution loop
# only refreshes lattice_host on ticks divisible by 20, so without this copy
# the report below would show a stale snapshot instead of the final state.
# The copy into a host array blocks by default, so all queued kernel launches
# have finished once it returns.
cl.enqueue_copy(queue, lattice_host, lattice_buf)

print("\nHDGL-supervised lattice complete (aggregated per instance):")
for i in range(INSTANCES):
    slice_start = i * SLOTS_PER_INSTANCE
    slice_end = slice_start + SLOTS_PER_INSTANCE
    # Encode each slot as one bit (1 when the value is at least 1), first
    # slot as the most significant bit, then print the bit string as hex.
    bits = ''.join('1' if val >= 1 else '0' for val in lattice_host[slice_start:slice_end])
    hex_repr = hex(int(bits, 2))
    print(f"Instance {i+1} hex: {hex_repr}")
